# hhmm_lib.py
import numpy as np
import torch
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Any

# ---------- numerics ----------
# Very negative log-value used to pre-fill the forward/backward/Viterbi tables
# in this file before their rows are overwritten.
LOGTINY = -1e12
# Small positive floor applied to probabilities and counts before taking a log
# or dividing, so neither log(0) nor division by zero occurs.
EPS = 1e-8

def logsumexp(a, axis=None, keepdims=False):
    """
    Numerically stable log(sum(exp(a))) along `axis`.

    a:        array-like of log-values.
    axis:     axis (or None for all elements) to reduce over.
    keepdims: if True, the reduced axis is kept with length 1.
    returns:  ndarray of the reduced values (0-d when axis is None and
              keepdims is False).

    Slices whose maximum is not finite are handled explicitly: an all -inf
    slice yields -inf and a slice containing +inf yields +inf, instead of the
    NaN that `inf - inf` would otherwise produce.
    """
    a = np.asarray(a)
    m = np.max(a, axis=axis, keepdims=True)
    # Shifting by a non-finite max would evaluate inf - inf = NaN; shift by 0 there.
    shift = np.where(np.isfinite(m), m, 0.0)
    # After shifting by a finite max the sum is >= 1 (the max term contributes
    # exp(0)), so no lower floor is needed. log(0) = -inf is the intended
    # result for all -inf slices, hence the silenced divide warning.
    with np.errstate(divide="ignore"):
        out = shift + np.log(np.sum(np.exp(a - shift), axis=axis, keepdims=True))
    if not keepdims:
        out = np.squeeze(out, axis=axis)
    return out

# ---------- semantic tag space (anchored top) ----------
# Canonical semantic tags. A tag's position in this list is its integer id.
CANON_TAGS = [
    "final_answer",                 # id 0
    "setup_and_retrieval",          # id 1
    "analysis_and_computation",     # id 2
    "uncertainty_and_verification", # id 3
]
# Reverse lookup: tag string -> integer id (its index in CANON_TAGS).
CANON_TAG2ID = {t: i for i, t in enumerate(CANON_TAGS)}

def _unwrap1(x: Any) -> Any:
    """Return the first element of a non-empty tuple/list; any other value is returned as-is."""
    is_nonempty_seq = isinstance(x, (tuple, list)) and len(x) > 0
    return x[0] if is_nonempty_seq else x

def _label_str_to_canon_id(s: str) -> int:
    """
    Look up the integer id of a canonical tag string in CANON_TAG2ID.

    Raises KeyError when the string is not one of the canonical tags.
    """
    tag_id = CANON_TAG2ID.get(s)
    if tag_id is None:
        raise KeyError(f"Unknown label string: {s!r}")
    return tag_id

def coerce_labels_to_ids(labels: List[Any]) -> List[int]:
    """
    Convert a list of heterogeneous label items to canonical integer ids.

    Accepts items like:
      - int/float holding a whole number in 0..3,
      - str (canonical 4-tag),
      - 1-item tuple/list wrapping one of the above,
      - (sentence_text, label) pairs: any 2-item tuple/list whose second
        element is a str or a number is treated as such a pair and the
        second element is taken as the label.
    Returns List[int] in {0,1,2,3}.

    Raises:
      KeyError   - unknown label string.
      ValueError - numeric label out of range, non-finite, or not a whole
                   number (e.g. 2.5; previously such values were silently
                   truncated, so -0.5 was accepted as 0).
      TypeError  - unsupported label type.
    """
    out: List[int] = []
    for v in labels:
        is_pair = (
            isinstance(v, (tuple, list))
            and len(v) == 2
            and isinstance(v[1], (str, int, float, np.integer, np.floating))
        )
        v = v[1] if is_pair else _unwrap1(v)

        if isinstance(v, str):
            out.append(_label_str_to_canon_id(v))
        elif isinstance(v, (int, np.integer, float, np.floating)):
            # Reject floats that are not exact integers instead of truncating them.
            if isinstance(v, (float, np.floating)) and (not np.isfinite(v) or v != int(v)):
                raise ValueError(f"Numeric label {v} is not a whole number. Must be in {{0,1,2,3}}.")
            iv = int(v)
            if iv < 0 or iv > 3:
                raise ValueError(f"Numeric label {v} out of range. Must be in {{0,1,2,3}}.")
            out.append(iv)
        else:
            raise TypeError(f"Unsupported label type: {type(v)} value={v}")
    return out

def extract_labels_from_sentences_with_labels(swls: List[Any]) -> List[int]:
    """
    Validate that `swls` is a list and convert its items to canonical label ids.

    Raises TypeError when `swls` is not a list; otherwise delegates to
    coerce_labels_to_ids and returns its result.
    """
    if isinstance(swls, list):
        return coerce_labels_to_ids(swls)
    raise TypeError("sentences_with_labels must be a list.")

# ---------- data adapters ----------
def load_pt_records(pt_path: str):
    """
    Load a torch-serialized file onto the CPU and return its "records" entry.

    Raises KeyError if the loaded object has no "records" key.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version / weights_only default - only use on trusted files.
    payload = torch.load(pt_path, map_location="cpu")
    records = payload["records"]
    return records

def build_top_sequences(records) -> List[Dict]:
    """
    Convert raw records into the sequence dicts consumed by training/decoding.

    Returns list of sequences; each sequence:
      {
        "steps": [ np.array(L, D), ... ]            # per step: layer-by-D vectors (float32)
        "sentences_with_labels": List[(str, str)]   # REQUIRED for anchored training
      }

    Per record:
      - "step_hidden_states" (missing or None is treated as empty) supplies the
        steps; entries that are None are skipped.
      - labels are copied from "sentences_with_labels", falling back to the
        legacy key "label_sequence_4cat"; if neither is present the sequence is
        emitted without labels.
      - records that end up with no steps are dropped.

    Tensors are detached before conversion, so hidden states that still
    require grad no longer make `.numpy()` raise.
    """
    seqs: List[Dict] = []
    for r in records:
        hs_list = r.get("step_hidden_states", [])
        if hs_list is None:
            # Key present but explicitly None: treat like a missing key.
            hs_list = []
        steps: List[np.ndarray] = []
        for H in hs_list:
            if H is None:
                # NOTE(review): skipping a step does not drop its label, so later
                # steps may pair with shifted labels - confirm against the producer.
                continue
            if isinstance(H, torch.Tensor):
                arr = H.detach().to(torch.float32).cpu().numpy()  # [L, D]
            else:
                arr = torch.as_tensor(H, dtype=torch.float32).detach().cpu().numpy()
            steps.append(arr)
        if not steps:
            continue
        seq: Dict[str, Any] = {"steps": steps}
        if "sentences_with_labels" in r:
            seq["sentences_with_labels"] = r["sentences_with_labels"]
        elif "label_sequence_4cat" in r:
            # allow legacy key name
            seq["sentences_with_labels"] = r["label_sequence_4cat"]
        # else: anchoring will later raise if labels missing
        seqs.append(seq)
    return seqs

# ---------- bottom HMM (per-category) with diagonal Gaussian emissions ----------
@dataclass
class BottomHMMParams:
    """Parameters of one bottom-level HMM with K states and diagonal-Gaussian emissions over D features."""
    startprob: np.ndarray      # [K]     initial state distribution
    transmat: np.ndarray       # [K, K]  state transition matrix (row = from-state)
    means: np.ndarray          # [K, D]  per-state emission means
    variances: np.ndarray      # [K, D]  per-state diagonal covariance entries

def init_bottom_params(K: int, D: int, x_samples: np.ndarray, sticky=0.90, rs: Optional[np.random.RandomState]=None) -> BottomHMMParams:
    """
    Deterministic init via provided RandomState rs (RandomState(0) if None):
      - pick a small base pool of up to 10*K rows of x_samples (without replacement)
      - seed K means from it (with replacement)
      - set diagonal variances to the pool's per-feature variance + 1e-2
      - uniform start distribution; transition matrix with `sticky` on the
        diagonal and the remaining mass spread evenly off-diagonal

    K:         number of hidden states.
    D:         feature dimension (not used here; the shapes follow x_samples).
    x_samples: [N, D] rows to draw the initial means from.

    Edge cases handled so the result is always usable:
      - K == 1: the single transition row is [1.0] (a valid distribution)
        rather than [sticky].
      - a pool of a single row: the variance falls back to ddof=0 (zero
        spread + 1e-2) instead of the NaN produced by ddof=1.
    """
    if rs is None:
        rs = np.random.RandomState(0)
    pool_n = min(10 * K, x_samples.shape[0])
    idx = rs.choice(x_samples.shape[0], size=pool_n, replace=False)
    base = x_samples[idx]                                  # [pool_n, D]
    means = base[rs.choice(base.shape[0], size=K, replace=True)]
    # The unbiased estimator divides by (n - 1): undefined for a single sample.
    ddof = 1 if base.shape[0] > 1 else 0
    variances = np.tile(np.var(base, axis=0, ddof=ddof) + 1e-2, (K, 1))
    startprob = np.ones(K, dtype=np.float64) / max(K, 1)
    if K > 1:
        transmat = np.full((K, K), (1.0 - sticky) / (K - 1), dtype=np.float64)
        np.fill_diagonal(transmat, sticky)
    else:
        # With one state there is nowhere else to go: the row must sum to 1.
        transmat = np.ones((K, K), dtype=np.float64)
    return BottomHMMParams(startprob, transmat, means, variances)

def log_gaussian_diag(x: np.ndarray, means: np.ndarray, variances: np.ndarray) -> np.ndarray:
    """
    Diagonal-Gaussian log-density of every row of x under every state.

    x: [T, D]; means/variances: [K, D]
    returns log-likelihood per state: [T, K]

    Variances are floored at 1e-6 before use.
    """
    _, n_dims = x.shape
    safe_var = np.maximum(variances, 1e-6)
    precision = 1.0 / safe_var
    half_logdet = 0.5 * np.sum(np.log(safe_var), axis=1)                    # [K]
    centered = x[:, None, :] - means[None, :, :]                            # [T, K, D]
    half_mahal = 0.5 * np.sum(centered * centered * precision[None, :, :], axis=2)  # [T, K]
    norm_const = 0.5 * n_dims * np.log(2 * np.pi)
    return -(norm_const + half_mahal + half_logdet[None, :])

def fb_bottom(x: np.ndarray, params: BottomHMMParams):
    """
    Log-space forward-backward over one step's layer sequence under one
    category's bottom HMM.

    x: observation rows passed to log_gaussian_diag (documented there as [T, D]).
    returns (logZ, gamma, xi):
      logZ  - log-likelihood of the whole sequence,
      gamma - [T, K] state posteriors per position,
      xi    - [K, K] expected transition counts summed over positions.
    Note the return order differs from fb_top, which returns (gamma, xi, logZ).
    """
    log_pi = np.log(np.maximum(EPS, params.startprob))
    log_A = np.log(np.maximum(EPS, params.transmat))
    log_b = log_gaussian_diag(x, params.means, params.variances)  # [T, K]
    n_obs, n_states = log_b.shape

    # Forward pass.
    fwd = np.full((n_obs, n_states), LOGTINY)
    fwd[0] = log_pi + log_b[0]
    for t in range(1, n_obs):
        fwd[t] = log_b[t] + logsumexp(fwd[t - 1][:, None] + log_A, axis=0)
    logZ = logsumexp(fwd[-1], axis=0)

    # Backward pass.
    bwd = np.full((n_obs, n_states), LOGTINY)
    bwd[-1] = 0.0
    for t in reversed(range(n_obs - 1)):
        bwd[t] = logsumexp(log_A + (log_b[t + 1] + bwd[t + 1])[None, :], axis=1)

    # Posteriors and expected transition counts.
    gamma = np.exp(fwd + bwd - logZ)  # [T, K]
    xi = np.zeros((n_states, n_states), dtype=np.float64)
    for t in range(n_obs - 1):
        xi += np.exp(
            fwd[t][:, None] + log_A + log_b[t + 1][None, :] + bwd[t + 1][None, :] - logZ
        )
    return logZ, gamma, xi

def mstep_bottom(accum, min_var=1e-3) -> BottomHMMParams:
    """
    Turn accumulated sufficient statistics into bottom-HMM parameters.

    accum keys: 'start' [K], 'trans' [K,K], 'sum_w' [K], 'sum_x' [K,D], 'sum_x2' [K,D]
    min_var:    lower bound applied to every variance entry.
    """
    # Initial distribution: floor, then normalise (uniform if the sum is not positive).
    pi = np.maximum(EPS, accum["start"])
    if pi.sum() > 0:
        pi = pi / np.sum(pi)
    else:
        pi = np.ones_like(pi) / len(pi)

    # Transition rows: floor, then normalise each row.
    A = np.maximum(EPS, accum["trans"])
    A = A / np.maximum(EPS, np.sum(A, axis=1, keepdims=True))

    # Weighted moments: mean = E[x], variance = E[x^2] - E[x]^2 (floored at min_var).
    weight = np.maximum(EPS, accum["sum_w"])[:, None]
    mu = accum["sum_x"] / weight
    var = np.maximum(accum["sum_x2"] / weight - mu * mu, min_var)
    return BottomHMMParams(pi, A, mu, var)

# ---------- top HMM (per-sequence steps) ----------
@dataclass
class TopHMMParams:
    """Parameters of the top-level HMM over C categories (no emission parameters are stored here)."""
    startprob: np.ndarray   # [C]     initial category distribution
    transmat: np.ndarray    # [C, C]  category transition matrix (row = from-category)

def init_top_params(C: int, sticky=0.80):
    """
    Build initial top-level HMM parameters for C categories.

    The start distribution is uniform. The transition matrix puts `sticky`
    on the diagonal and spreads the remaining mass evenly over the other
    categories. With a single category the only row is [1.0], so it remains
    a valid distribution rather than [sticky].

    returns: TopHMMParams(startprob [C], transmat [C, C])
    """
    startprob = np.ones(C, dtype=np.float64) / max(C, 1)
    if C > 1:
        trans = np.full((C, C), (1.0 - sticky) / (C - 1), dtype=np.float64)
        np.fill_diagonal(trans, sticky)
    else:
        # One category: no off-diagonal entries to absorb the 1 - sticky mass.
        trans = np.ones((C, C), dtype=np.float64)
    return TopHMMParams(startprob, trans)

def fb_top(log_emissions: np.ndarray, params: TopHMMParams):
    """
    Log-space forward-backward over the step sequence of the top HMM.

    log_emissions: [T_steps, C]
    returns: gamma [T,C], xi [C,C], logZ
    """
    log_pi = np.log(np.maximum(EPS, params.startprob))
    log_A = np.log(np.maximum(EPS, params.transmat))
    n_steps, n_cats = log_emissions.shape

    # Forward pass.
    fwd = np.full((n_steps, n_cats), LOGTINY)
    fwd[0] = log_pi + log_emissions[0]
    for t in range(1, n_steps):
        fwd[t] = log_emissions[t] + logsumexp(fwd[t - 1][:, None] + log_A, axis=0)
    logZ = logsumexp(fwd[-1], axis=0)

    # Backward pass.
    bwd = np.full((n_steps, n_cats), LOGTINY)
    bwd[-1] = 0.0
    for t in reversed(range(n_steps - 1)):
        future = log_emissions[t + 1] + bwd[t + 1]
        bwd[t] = logsumexp(log_A + future[None, :], axis=1)

    # Posteriors and expected transition counts.
    gamma = np.exp(fwd + bwd - logZ)
    xi = np.zeros((n_cats, n_cats), dtype=np.float64)
    for t in range(n_steps - 1):
        future = log_emissions[t + 1] + bwd[t + 1]
        xi += np.exp(fwd[t][:, None] + log_A + future[None, :] - logZ)
    return gamma, xi, logZ

# ---------- HHMM container ----------
@dataclass
class HHMM:
    """Container for a two-level model: one top HMM plus one bottom HMM per top category."""
    C: int                         # number of top-level categories
    K: int                         # number of states in each bottom HMM
    D: int                         # feature dimension of the observations
    top: TopHMMParams              # top-level start/transition parameters
    bottom: List[BottomHMMParams]  # length C

def init_hhmm(C: int, K: int, D: int, sequences, rs: Optional[np.random.RandomState]=None) -> HHMM:
    """
    Build an initial HHMM from the pooled observations of all sequences.

    Every step array of every sequence is concatenated along axis 0 and used
    as the sample pool for each of the C bottom HMMs; all bottom inits draw
    from the same RandomState (RandomState(0) if none is given). The top HMM
    gets its default initialisation.

    Raises ValueError when no sequence contains any step.
    """
    rs = rs or np.random.RandomState(0)
    all_steps: List[np.ndarray] = [step for seq in sequences for step in seq["steps"]]
    if not all_steps:
        raise ValueError("No steps to initialize from.")
    X = np.concatenate(all_steps, axis=0)  # [sum L, D]
    top = init_top_params(C)
    bottom = []
    for _ in range(C):
        bottom.append(init_bottom_params(K, D, X, sticky=0.90, rs=rs))
    return HHMM(C=C, K=K, D=D, top=top, bottom=bottom)

# ---------- EM training (anchored top with semantic labels) ----------
def fit_hhmm_fixed_top(
    sequences: List[Dict],
    C: int,
    K: int,
    label_key: str = "sentences_with_labels",
    n_iter: int = 10,
    min_var: float = 1e-3,
    seed: int = 0,
    verbose: bool = True,
) -> HHMM:
    """
    Anchored HHMM: the top category at each step is FIXED by provided labels.
    We estimate:
      - Top startprob/transmat by counting over the fixed label paths.
      - Bottom HMM params per category via FB over layers, conditioned on that category.

    sequences: dicts with "steps" (list of per-step arrays) and labels under
               `label_key`; steps and labels are paired up to the shorter length.
    C, K:      number of top categories / bottom states per category.
    n_iter:    number of EM iterations (n_iter < 1 returns the initial model).
    min_var:   variance floor used in the bottom M-step.
    seed:      seed for the RandomState used for initialisation.
    verbose:   print the average per-step log-likelihood after each iteration.

    Raises:
      ValueError   - no sequence contains any step.
      RuntimeError - a sequence has no labels under `label_key`.
      (plus whatever coerce_labels_to_ids raises for malformed labels)

    Because the labels are fixed, they are coerced and counted once up front
    rather than in every iteration. Categories that never occur in the
    labels keep their initial bottom parameters; re-estimating them from
    all-zero statistics would collapse them to zero mean and `min_var`.
    """
    # Infer the feature dimension from the first available step.
    D: Optional[int] = next(
        (int(x.shape[1]) for seq in sequences for x in seq["steps"]), None
    )
    if D is None:
        raise ValueError("Empty sequences.")

    rs = np.random.RandomState(seed)
    model = init_hhmm(C, K, D, sequences, rs)
    if n_iter < 1:
        return model

    # ---- one-off pass over the fixed labels: pairing + top-level counts ----
    top_start_counts = np.zeros(C, dtype=np.float64)
    top_trans_counts = np.zeros((C, C), dtype=np.float64)
    category_seen = np.zeros(C, dtype=bool)
    anchored: List[Tuple[List[np.ndarray], List[int]]] = []
    n_steps_total = 0
    for seq in sequences:
        steps = seq["steps"]
        labels_raw = seq.get(label_key, None)
        if labels_raw is None:
            raise RuntimeError(
                f"Anchored training requires labels: missing '{label_key}' in a sequence."
            )
        y = coerce_labels_to_ids(labels_raw)  # 0..3
        Tsteps = min(len(steps), len(y))
        if Tsteps == 0:
            continue
        n_steps_total += Tsteps

        # top counts (hard labels)
        top_start_counts[y[0]] += 1.0
        for t in range(Tsteps - 1):
            top_trans_counts[y[t], y[t + 1]] += 1.0
        for c in y[:Tsteps]:
            category_seen[c] = True
        anchored.append((steps[:Tsteps], y[:Tsteps]))

    # ---- top M-step: depends only on the labels, so it is done once ----
    # Flooring at EPS keeps every entry positive, so the sums are never zero
    # and unobserved rows become uniform.
    startprob = np.maximum(EPS, top_start_counts)
    startprob = startprob / np.sum(startprob)
    transmat = np.maximum(EPS, top_trans_counts)
    transmat = transmat / np.sum(transmat, axis=1, keepdims=True)
    model.top = TopHMMParams(startprob, transmat)

    # ---- EM over the bottom HMMs ----
    for it in range(1, n_iter + 1):
        bottom_acc = [{
            "start": np.zeros(K),
            "trans": np.zeros((K, K)),
            "sum_w": np.zeros(K),
            "sum_x": np.zeros((K, D)),
            "sum_x2": np.zeros((K, D)),
        } for _ in range(C)]
        total_logZ = 0.0

        # bottom E-step: per-step, condition on the fixed category c
        for steps, y in anchored:
            for x, c in zip(steps, y):
                logZ, gamma_r, xi_r = fb_bottom(x, model.bottom[c])
                total_logZ += logZ

                acc = bottom_acc[c]
                acc["start"] += gamma_r[0]
                acc["trans"] += xi_r
                acc["sum_w"] += np.sum(gamma_r, axis=0)
                acc["sum_x"] += (gamma_r.T @ x)
                acc["sum_x2"] += (gamma_r.T @ (x * x))

        # M-step for bottoms (only categories that actually have data)
        for c in range(C):
            if category_seen[c]:
                model.bottom[c] = mstep_bottom(bottom_acc[c], min_var=min_var)

        if verbose:
            avg_ll = total_logZ / max(1, n_steps_total)
            print(f"[HHMM-anchored] iter {it:02d}  avg step loglik = {avg_ll:.4f}", flush=True)

    return model

# ---------- decoding (anchored only) ----------
def viterbi_bottom(x: np.ndarray, params: BottomHMMParams) -> Tuple[float, np.ndarray]:
    """
    Most likely bottom-state path for one step's layer sequence.

    x: observation rows passed to log_gaussian_diag (documented there as [T, D]).
    returns (score, z): the log-score of the best path and the [T] integer
    array of state indices along it.
    """
    log_pi = np.log(np.maximum(EPS, params.startprob))
    log_A = np.log(np.maximum(EPS, params.transmat))
    log_b = log_gaussian_diag(x, params.means, params.variances)  # [T, K]
    n_obs, n_states = log_b.shape

    best = np.full((n_obs, n_states), LOGTINY)
    back = np.full((n_obs, n_states), -1, dtype=int)
    best[0] = log_pi + log_b[0]
    for t in range(1, n_obs):
        cand = best[t - 1][:, None] + log_A   # [K,K]: rows = previous state
        back[t] = cand.argmax(axis=0)
        best[t] = log_b[t] + cand.max(axis=0)

    best_score = np.max(best[-1])

    # Trace the back-pointers from the best final state.
    z = np.zeros(n_obs, dtype=int)
    z[-1] = int(np.argmax(best[-1]))
    for t in reversed(range(n_obs - 1)):
        z[t] = back[t + 1, z[t + 1]]
    return best_score, z

def decode_hhmm_anchored(sequences: List[Dict], model: HHMM, label_key: str = "sentences_with_labels"):
    """
    Anchored decode:
      The provided semantic labels are taken as the top categories; for each
      step the bottom regimes are decoded with Viterbi under that category's
      bottom HMM. Steps and labels are paired up to the shorter length.

    returns: one dict per sequence with
      "best_categories"       - list of category ids (the labels themselves),
      "best_regimes_per_step" - list (per step) of bottom-state index lists.

    Raises RuntimeError if a sequence has no labels under `label_key`
    (since we run anchored-only).
    """
    results = []
    for seq in sequences:
        steps = seq["steps"]
        labels_raw = seq.get(label_key, None)
        if labels_raw is None:
            raise RuntimeError(f"Anchored decode requires labels: missing '{label_key}'.")
        y = coerce_labels_to_ids(labels_raw)
        n_paired = min(len(steps), len(y))

        cats = []
        regimes = []
        for t in range(n_paired):
            category = int(y[t])
            _, path = viterbi_bottom(steps[t], model.bottom[category])
            cats.append(category)
            regimes.append(path.tolist())
        results.append({"best_categories": cats, "best_regimes_per_step": regimes})
    return results